/*
 * AIEngine a new generation network intrusion detection system.
 *
 * Copyright (C) 2013-2023  Luis Campo Giralte
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301, USA.
 *
 * Written by Luis Campo Giralte <luis.camp0.2009@gmail.com>
 *
 */
#include "NetworkStack.h"

namespace aiengine {

// Wires together the protocol graph shared by every concrete network stack.
// The body only connects already-existing objects; it does not create them.
// NOTE(review): the protocols, multiplexers, flow forwarders and anomaly_ are
// dereferenced here without being constructed in this body, so they are
// presumably created by in-class member initializers in NetworkStack.h —
// confirm there.
NetworkStack::NetworkStack() {

        // Connect the Protocols with the Multiplexers
        // (each link-layer protocol points to its multiplexer and back)
        eth->setMultiplexer(mux_eth);
        mux_eth->setProtocol(static_cast<ProtocolPtr>(eth));

        vlan->setMultiplexer(mux_vlan);
        mux_vlan->setProtocol(static_cast<ProtocolPtr>(vlan));

        mpls->setMultiplexer(mux_mpls);
        mux_mpls->setProtocol(static_cast<ProtocolPtr>(mpls));

        pppoe->setMultiplexer(mux_pppoe);
        mux_pppoe->setProtocol(static_cast<ProtocolPtr>(pppoe));

	// Connect the protocols with the FlowForwarders
	// (each flow-level protocol points to its forwarder and back)
        http->setFlowForwarder(ff_http);
        ff_http->setProtocol(static_cast<ProtocolPtr>(http));

        ssl->setFlowForwarder(ff_ssl);
        ff_ssl->setProtocol(static_cast<ProtocolPtr>(ssl));

        dns->setFlowForwarder(ff_dns);
        ff_dns->setProtocol(static_cast<ProtocolPtr>(dns));

        sip->setFlowForwarder(ff_sip);
        ff_sip->setProtocol(static_cast<ProtocolPtr>(sip));

        dhcp->setFlowForwarder(ff_dhcp);
        ff_dhcp->setProtocol(static_cast<ProtocolPtr>(dhcp));

        ntp->setFlowForwarder(ff_ntp);
        ff_ntp->setProtocol(static_cast<ProtocolPtr>(ntp));

        snmp->setFlowForwarder(ff_snmp);
        ff_snmp->setProtocol(static_cast<ProtocolPtr>(snmp));

        ssdp->setFlowForwarder(ff_ssdp);
        ff_ssdp->setProtocol(static_cast<ProtocolPtr>(ssdp));

        netbios->setFlowForwarder(ff_netbios);
        ff_netbios->setProtocol(static_cast<ProtocolPtr>(netbios));

        quic->setFlowForwarder(ff_quic);
        ff_quic->setProtocol(static_cast<ProtocolPtr>(quic));

        coap->setFlowForwarder(ff_coap);
        ff_coap->setProtocol(static_cast<ProtocolPtr>(coap));

        rtp->setFlowForwarder(ff_rtp);
        ff_rtp->setProtocol(static_cast<ProtocolPtr>(rtp));

        smtp->setFlowForwarder(ff_smtp);
        ff_smtp->setProtocol(static_cast<ProtocolPtr>(smtp));

        imap->setFlowForwarder(ff_imap);
        ff_imap->setProtocol(static_cast<ProtocolPtr>(imap));

        pop->setFlowForwarder(ff_pop);
        ff_pop->setProtocol(static_cast<ProtocolPtr>(pop));

        bitcoin->setFlowForwarder(ff_bitcoin);
        ff_bitcoin->setProtocol(static_cast<ProtocolPtr>(bitcoin));

        modbus->setFlowForwarder(ff_modbus);
        ff_modbus->setProtocol(static_cast<ProtocolPtr>(modbus));

        mqtt->setFlowForwarder(ff_mqtt);
        ff_mqtt->setProtocol(static_cast<ProtocolPtr>(mqtt));

        smb->setFlowForwarder(ff_smb);
        ff_smb->setProtocol(static_cast<ProtocolPtr>(smb));

        ssh->setFlowForwarder(ff_ssh);
        ff_ssh->setProtocol(static_cast<ProtocolPtr>(ssh));

        dcerpc->setFlowForwarder(ff_dcerpc);
        ff_dcerpc->setProtocol(static_cast<ProtocolPtr>(dcerpc));

        dtls->setFlowForwarder(ff_dtls);
        ff_dtls->setProtocol(static_cast<ProtocolPtr>(dtls));

        // Catch-all TCP/UDP analyzers
        tcp_generic->setFlowForwarder(ff_tcp_generic);
        ff_tcp_generic->setProtocol(static_cast<ProtocolPtr>(tcp_generic));

        udp_generic->setFlowForwarder(ff_udp_generic);
        ff_udp_generic->setProtocol(static_cast<ProtocolPtr>(udp_generic));

        // Frequency analyzers for TCP and UDP payloads
        freqs_tcp->setFlowForwarder(ff_tcp_freqs);
        ff_tcp_freqs->setProtocol(static_cast<ProtocolPtr>(freqs_tcp));

        freqs_udp->setFlowForwarder(ff_udp_freqs);
        ff_udp_freqs->setProtocol(static_cast<ProtocolPtr>(freqs_udp));

	// Sets the AnomalyManager on protocols that could generate an anomaly.
	// All of them share the single anomaly_ instance owned by the stack.
        dns->setAnomalyManager(anomaly_);
        snmp->setAnomalyManager(anomaly_);
        coap->setAnomalyManager(anomaly_);
        rtp->setAnomalyManager(anomaly_);
        sip->setAnomalyManager(anomaly_);
        http->setAnomalyManager(anomaly_);
        ssl->setAnomalyManager(anomaly_);
        smtp->setAnomalyManager(anomaly_);
        pop->setAnomalyManager(anomaly_);
        imap->setAnomalyManager(anomaly_);
        mqtt->setAnomalyManager(anomaly_);
        netbios->setAnomalyManager(anomaly_);
        dhcp->setAnomalyManager(anomaly_);
}

// Explicitly releases state held by the stack: the stack name, the TCP/UDP
// regex and IP-set managers, and the "current" UDP/TCP flow forwarders that
// enableProtocol()/disableProtocol() attach protocols to. The resets run in
// the order written below, ahead of the implicit member destruction.
NetworkStack::~NetworkStack() {

	name_.clear();
        tcp_regex_mng_.reset();
        udp_regex_mng_.reset();
        tcp_ipset_mng_.reset();
        udp_ipset_mng_.reset();
        ff_udp_current_.reset();
        ff_tcp_current_.reset();
}

// Looks up a registered protocol by name, ignoring case. The comparison uses
// each protocol's own name(); the first match in registration order wins.
// Returns an empty ProtocolPtr when nothing matches.
ProtocolPtr NetworkStack::get_protocol(const std::string &name) const {

	auto match = std::find_if(proto_vector_.begin(), proto_vector_.end(),
		[&name] (const ProtocolPair &entry) {
			return boost::iequals(name, entry.second->name());
		});

	if (match == proto_vector_.end())
		return ProtocolPtr();

	return match->second;
}

// Registers a protocol under its own name. Registration order is the order in
// which statistics and summaries iterate over the protocols.
void NetworkStack::addProtocol(ProtocolPtr proto) {

	proto_vector_.emplace_back(proto->name(), proto);
}

// Same as addProtocol(proto), but first sets the protocol's active flag.
void NetworkStack::addProtocol(ProtocolPtr proto, bool active) {

	proto->active(active);

	addProtocol(proto);
}

int64_t NetworkStack::getAllocatedMemory() const {

	int64_t value = 0;

	for (auto &p: proto_vector_)
		value += (p.second)->getAllocatedMemory();

	return value;
}

int64_t NetworkStack::getTotalAllocatedMemory() const {

	int64_t value = 0;

	for (auto &p: proto_vector_)
		value += (p.second)->getTotalAllocatedMemory();

	return value;
}

// Prints one protocol's statistics to the shared output with no entry limit.
void NetworkStack::statistics(const std::string &name, int level) const {

	statistics(OutputManager::getInstance()->out(), name, level, std::numeric_limits<int>::max());
}

// Prints one protocol's statistics to `out`. Nothing is printed when level is
// not positive, or when the protocol is unknown or inactive.
void NetworkStack::statistics(std::basic_ostream<char> &out, const std::string &name, int level, int32_t limit) const {

	if (level <= 0)
		return;

	ProtocolPtr proto = get_protocol(name);
	if (!proto or !proto->active())
		return;

	proto->statistics(out, level, limit);
	out << std::endl;
}

// Json variant of statistics(name, level) with no entry limit.
void NetworkStack::statistics(Json &out, const std::string &name, int level) const {

	statistics(out, name, level, std::numeric_limits<int>::max());
}

// Fills `out` with one protocol's statistics; same skip rules as the stream variant.
void NetworkStack::statistics(Json &out, const std::string &name, int level, int32_t limit) const {

	if (level <= 0)
		return;

	ProtocolPtr proto = get_protocol(name);

	// NOTE(review): `limit` is accepted but not forwarded to the protocol;
	// confirm whether Protocol offers a Json overload that takes a limit.
	if (proto and proto->active())
		proto->statistics(out, level);
}

// Prints one protocol's statistics using the stack-wide level (stats_level_).
void NetworkStack::statistics(const std::string &name) const {

	statistics(name, stats_level_);
}

void NetworkStack::statistics(std::basic_ostream<char> &out, int level, int32_t limit) const {

	if (level > 0) {
		std::for_each (proto_vector_.begin(), proto_vector_.end(), [&] (ProtocolPair const &pp) {
			ProtocolPtr proto = pp.second;

			if (proto->active()) {
				proto->statistics(out, level, limit);
				out << std::endl;
			}
		});
	}
}

void NetworkStack::statistics(int level) const {

	statistics(OutputManager::getInstance()->out(), level, std::numeric_limits<int>::max());
}

void NetworkStack::setStatisticsLevel(int level) {

        stats_level_ = level;

	std::for_each (proto_vector_.begin(), proto_vector_.end(), [&] (ProtocolPair const &pp) {
		ProtocolPtr proto = pp.second;

		proto->setStatisticsLevel(level);
	});
}

std::ostream& operator<< (std::ostream &out, const NetworkStack &ns) {

	ns.statistics(out, ns.stats_level_, std::numeric_limits<int>::max());

        return out;
}

// This method is only executed by users under the shell control
void NetworkStack::showFlows(int limit) const {

	UserFlowOptions options;

	options.limit = limit;

	show_selected_flows(OutputManager::getInstance()->out(), options);
}

void NetworkStack::showFlows(const std::string& protoname, int limit) const {

	UserFlowOptions options;

	options.limit = limit;
	options.l7protocol_name = protoname;

	show_selected_flows(OutputManager::getInstance()->out(), options);
}

void NetworkStack::showFlows(const std::string &protoname) const {

	UserFlowOptions options;

	options.l7protocol_name = protoname;

	show_selected_flows(OutputManager::getInstance()->out(), options);
}

void NetworkStack::showFlows() const {

	UserFlowOptions options;

	show_selected_flows(OutputManager::getInstance()->out(), options);
}

void NetworkStack::showFlows(std::basic_ostream<char> &out) const {

	UserFlowOptions options;

	show_selected_flows(out, options);
}

void NetworkStack::statistics() const {

	statistics(OutputManager::getInstance()->out());
}

void NetworkStack::statistics(std::basic_ostream<char> &out) const {

	out << *this;
}

void NetworkStack::setDomainNameManager(const SharedPointer<DomainNameManager> &dnm, const std::string &name) {

	setDomainNameManager(dnm, name, true);
}

void NetworkStack::setDomainNameManager(const SharedPointer<DomainNameManager> &dnm, const std::string &name, bool allow) {

        if (ProtocolPtr pp = get_protocol(name); pp) {
		if (allow)
			pp->setDomainNameManager(dnm);
		else
			pp->setDomainNameBanManager(dnm);
        }
}
#if defined(RUBY_BINDING) || defined(LUA_BINDING) || defined(JAVA_BINDING) || defined(GO_BINDING)

void NetworkStack::setDomainNameManager(const DomainNameManager &dnm, const std::string &name) {

	auto dm = std::make_shared<DomainNameManager>(dnm);

	setDomainNameManager(dm, name);
}

void NetworkStack::setDomainNameManager(const DomainNameManager &dnm, const std::string &name, bool allow) {

	auto dm = std::make_shared<DomainNameManager>(dnm);

	setDomainNameManager(dm, name, allow);
}

#endif

// Activates `proto` and inserts its forwarder as an upper forwarder of `ff`.
// Does nothing if `ff` is empty or the protocol's forwarder has expired.
void NetworkStack::enable_protocol(const ProtocolPtr &proto, const SharedPointer<FlowForwarder> &ff) {

	auto proto_ff = proto->getFlowForwarder().lock();

	if (!ff or !proto_ff)
		return;

	proto->active(true);
	ff->insertUpFlowForwarder(proto_ff);
}

// Deactivates `proto` and removes its forwarder from the upper forwarders of
// `ff`. Does nothing if `ff` is empty or the protocol's forwarder has expired.
void NetworkStack::disable_protocol(const ProtocolPtr &proto, const SharedPointer<FlowForwarder> &ff) {

	auto proto_ff = proto->getFlowForwarder().lock();

	if (!ff or !proto_ff)
		return;

	proto->active(false);
	ff->removeUpFlowForwarder(proto_ff);
}

#if defined(BINDING)

// Turns on an inactive protocol by hooking it under the current UDP or TCP
// forwarder, according to its protocol layer. Other layers are ignored.
void NetworkStack::enableProtocol(const std::string &name) {

	ProtocolPtr proto = get_protocol(name);
	if (!proto or proto->active())
		return;

	switch (proto->getProtocolLayer()) {
		case IPPROTO_UDP:
			enable_protocol(proto, ff_udp_current_);
			break;
		case IPPROTO_TCP:
			enable_protocol(proto, ff_tcp_current_);
			break;
		default:
			break;
	}
}

// Turns off an active protocol by unhooking it from the current UDP or TCP
// forwarder, according to its protocol layer. Other layers are ignored.
void NetworkStack::disableProtocol(const std::string &name) {

	ProtocolPtr proto = get_protocol(name);
	if (!proto or !proto->active())
		return;

	switch (proto->getProtocolLayer()) {
		case IPPROTO_UDP:
			disable_protocol(proto, ff_udp_current_);
			break;
		case IPPROTO_TCP:
			disable_protocol(proto, ff_tcp_current_);
			break;
		default:
			break;
	}
}

// Binding-specific overloads without an explicit sampling value: they install
// the UDP database adaptor passing default_update_frequency as packet_sampling.
#if defined(PYTHON_BINDING)
void NetworkStack::setUDPDatabaseAdaptor(boost::python::object &dbptr) {

	setUDPDatabaseAdaptor(dbptr, default_update_frequency);
}
#elif defined(RUBY_BINDING)
void NetworkStack::setUDPDatabaseAdaptor(VALUE dbptr) {

	setUDPDatabaseAdaptor(dbptr, default_update_frequency);
}
#elif defined(JAVA_BINDING) || defined(GO_BINDING)
void NetworkStack::setUDPDatabaseAdaptor(DatabaseAdaptor *dbptr) {

	setUDPDatabaseAdaptor(dbptr, default_update_frequency);
}
#elif defined(LUA_BINDING)
// Lua identifies the adaptor by the global object name inside the Lua state.
void NetworkStack::setUDPDatabaseAdaptor(lua_State *L, const char *obj_name) {

	setUDPDatabaseAdaptor(L, obj_name, default_update_frequency);
}
#endif

// TCP counterparts of the overloads above: same forwarding with
// default_update_frequency as packet_sampling.
#if defined(PYTHON_BINDING)
void NetworkStack::setTCPDatabaseAdaptor(boost::python::object &dbptr) {

	setTCPDatabaseAdaptor(dbptr, default_update_frequency);
}
#elif defined(RUBY_BINDING)
void NetworkStack::setTCPDatabaseAdaptor(VALUE dbptr) {

	setTCPDatabaseAdaptor(dbptr, default_update_frequency);
}
#elif defined(JAVA_BINDING) || defined(GO_BINDING)
void NetworkStack::setTCPDatabaseAdaptor(DatabaseAdaptor *dbptr) {

	setTCPDatabaseAdaptor(dbptr, default_update_frequency);
}
#elif defined(LUA_BINDING)
void NetworkStack::setTCPDatabaseAdaptor(lua_State *L, const char* obj_name) {

	setTCPDatabaseAdaptor(L,obj_name, default_update_frequency);
}
#endif

#if defined(PYTHON_BINDING)
void NetworkStack::setUDPDatabaseAdaptor(boost::python::object &dbptr, int packet_sampling) {
#elif defined(RUBY_BINDING)
void NetworkStack::setUDPDatabaseAdaptor(VALUE dbptr, int packet_sampling) {
#elif defined(JAVA_BINDING) || defined(GO_BINDING)
void NetworkStack::setUDPDatabaseAdaptor(DatabaseAdaptor *dbptr, int packet_sampling) {
#elif defined(LUA_BINDING)
void NetworkStack::setUDPDatabaseAdaptor(lua_State *L, const char *obj_name, int packet_sampling) {
#endif
        if (ProtocolPtr pp = get_protocol(UDPProtocol::default_name); pp) {
                if (UDPProtocolPtr proto = std::static_pointer_cast<UDPProtocol>(pp); proto) {
#if defined(LUA_BINDING)
                        proto->setDatabaseAdaptor(L, obj_name, packet_sampling);
#else
                        proto->setDatabaseAdaptor(dbptr, packet_sampling);
#endif
                }
        }
}

#if defined(PYTHON_BINDING)
void NetworkStack::setTCPDatabaseAdaptor(boost::python::object &dbptr, int packet_sampling) {
#elif defined(RUBY_BINDING)
void NetworkStack::setTCPDatabaseAdaptor(VALUE dbptr, int packet_sampling) {
#elif defined(JAVA_BINDING) || defined(GO_BINDING)
void NetworkStack::setTCPDatabaseAdaptor(DatabaseAdaptor *dbptr, int packet_sampling) {
#elif defined(LUA_BINDING)
void NetworkStack::setTCPDatabaseAdaptor(lua_State *L, const char *obj_name, int packet_sampling) {
#endif
        if (ProtocolPtr pp = get_protocol(TCPProtocol::default_name); pp) {
                if (TCPProtocolPtr proto = std::static_pointer_cast<TCPProtocol>(pp); proto) {
#if defined(LUA_BINDING)
                        proto->setDatabaseAdaptor(L, obj_name, packet_sampling);
#else
                        proto->setDatabaseAdaptor(dbptr, packet_sampling);
#endif
                }
        }
}

#if defined(PYTHON_BINDING)
void NetworkStack::setAnomalyCallback(PyObject *callback, const std::string &proto_name) {
#elif defined(RUBY_BINDING)
void NetworkStack::setAnomalyCallback(VALUE callback, const std::string &proto_name) {
#elif defined(JAVA_BINDING)
void NetworkStack::setAnomalyCallback(JaiCallback *callback, const std::string &proto_name) {
#elif defined(LUA_BINDING)
void NetworkStack::setAnomalyCallback(lua_State *L, const std::string &callback, const std::string &proto_name) {
#elif defined(GO_BINDING)
void NetworkStack::setAnomalyCallback(GoaiCallback *callback, const std::string &proto_name) {
#endif
	if (anomaly_) {
#if defined(LUA_BINDING)
		anomaly_->setCallback(L, callback, proto_name);
#else
		anomaly_->setCallback(callback, proto_name);
#endif
	}
}

#endif

// Re-routes `flow` to the forwarder of protocol `name`. This only happens when
// the protocol exists, its forwarder is alive, and its protocol layer matches
// the flow's. Any previous L7 protocol releases its per-flow info, and the
// flow's L7 and frequency objects are dropped.
void NetworkStack::attachTo(const SharedPointer<Flow> &flow, const std::string &name) {

	ProtocolPtr dst_proto = get_protocol(name);
	if (!dst_proto)
		return;

	auto dst_ff = dst_proto->getFlowForwarder().lock();
	if (!dst_ff)
		return;

	if (dst_proto->getProtocolLayer() != flow->getProtocol())
		return;

	// The flow may or may not already be attached to an L7 protocol.
	if (auto current_ff = flow->forwarder.lock()) {
		if (auto src_proto = current_ff->getProtocol())
			src_proto->releaseFlowInfo(flow.get());
	}

	flow->forwarder = dst_ff;

	// Drop references to objects created by the previous L7 analysis.
	flow->layer7info.reset();
	flow->frequencies.reset();
	flow->packet_frequencies.reset();
}

#if defined(PYTHON_BINDING)

void NetworkStack::setOnFailCacheCallback(const std::string &name, PyObject *callback) {

        if (ProtocolPtr pp = get_protocol(name); pp)
		pp->setOnFailCacheCallback(callback);
}

boost::python::dict NetworkStack::getCounters(const std::string &name) {
	boost::python::dict counters;

        if (ProtocolPtr pp = get_protocol(name); pp) {
		CounterMap cm = pp->getCounters();
        	counters = cm.getRawCounters();
        }

        return counters;
}

boost::python::dict NetworkStack::getCacheData(const std::string &protocol, const std::string &name) {
        boost::python::dict cache;

        if (ProtocolPtr pp = get_protocol(protocol); pp)
                cache = pp->getCacheData(name);

        return cache;
}

SharedPointer<Cache<StringCache>> NetworkStack::getCache(const std::string &protocol, const std::string &name) {
	SharedPointer<Cache<StringCache>> cache = nullptr;

        if (ProtocolPtr pp = get_protocol(protocol); pp)
		return pp->getCache(name);

	return cache;
}

#elif defined(RUBY_BINDING)

// Ruby: counters of protocol `name`, or Qnil if the protocol is unknown.
VALUE NetworkStack::getCounters(const std::string &name) {

	VALUE counters = Qnil;

	ProtocolPtr proto = get_protocol(name);
	if (!proto)
		return counters;

	CounterMap cm = proto->getCounters();
	counters = cm.getRawCounters();
	return counters;
}

// Ruby: cache `name` of protocol `protocol`, or Qnil if the protocol is unknown.
VALUE NetworkStack::getCacheData(const std::string &protocol, const std::string &name) {

	VALUE cache = Qnil;

	ProtocolPtr proto = get_protocol(protocol);
	if (proto)
		cache = proto->getCacheData(name);

	return cache;
}

#elif defined(JAVA_BINDING) || defined(GO_BINDING)

std::map<std::string, int32_t> NetworkStack::getCounters(const std::string &name) {
	std::map<std::string, int32_t> counters;

        if (ProtocolPtr pp = get_protocol(name); pp) {
		CounterMap cm = pp->getCounters();
        	counters = cm.getRawCounters();
        }

	return counters;
}

#elif defined(LUA_BINDING)

// Lua: counters of protocol `name`. Returns an empty map when the protocol is
// unknown or when no name is given (a null pointer).
std::map<std::string, int> NetworkStack::getCounters(const char *name) {
	std::map<std::string, int> counters;

	// Constructing std::string from a null pointer is undefined behaviour,
	// so a missing name is treated like an unknown protocol.
	if (name == nullptr)
		return counters;

	std::string sname(name);

	if (ProtocolPtr pp = get_protocol(sname); pp) {
		CounterMap cm = pp->getCounters();
		counters = cm.getRawCounters();
	}
	return counters;
}

#endif

void NetworkStack::releaseCache(const std::string &name) {

	if (ProtocolPtr proto = get_protocol(name); proto)
        	proto->releaseCache();
}

void NetworkStack::releaseCaches() {

	std::for_each (proto_vector_.begin(), proto_vector_.end(), [&] (ProtocolPair const &pp) {
        	ProtocolPtr proto = pp.second;

                proto->releaseCache();
        });
}

// Enables a group of protocols behind a common "head" forwarder.
// The first element of `ffs` is the head. For every following forwarder, the
// owning protocol is activated and its forwarder is added as an upper
// forwarder of the head. An empty list is a no-op.
void NetworkStack::enableFlowForwarders(const std::initializer_list<SharedPointer<FlowForwarder>> &ffs) {

	// Without this guard *begin() would read past the end and begin() + 1
	// would skip over end(), so the loop below would never terminate.
	if (ffs.size() == 0)
		return;

	SharedPointer<FlowForwarder> head_ff = *(ffs.begin());

	for (auto f = ffs.begin() + 1; f != ffs.end(); ++f) {
		ProtocolPtr proto = (*f)->getProtocol();

		proto->active(true);
		head_ff->addUpFlowForwarder(proto->getFlowForwarder().lock());
	}
}

// Disables a group of protocols behind a common "head" forwarder (the first
// element of `ffs`). Each following protocol is deactivated and detached from
// the head through disable_protocol(). An empty list is a no-op.
void NetworkStack::disableFlowForwarders(const std::initializer_list<SharedPointer<FlowForwarder>> &ffs) {

	// Same guard as enableFlowForwarders(): begin() + 1 is invalid on an empty list.
	if (ffs.size() == 0)
		return;

	SharedPointer<FlowForwarder> head_ff = *(ffs.begin());

	for (auto f = ffs.begin() + 1; f != ffs.end(); ++f) {
		ProtocolPtr proto = (*f)->getProtocol();

		disable_protocol(proto, head_ff);
	}
}

// Forwards `msg` to aiengine::information_message().
void NetworkStack::infoMessage(const std::string &msg) {

	aiengine::information_message(msg);
}

// Selects the link-layer tagging ("vlan", "mpls" or "pppoe") by splicing its
// multiplexer between Ethernet and IP. All tagging protocols are deactivated
// first and only the selected one is re-activated. An unknown type logs a
// message and clears link_layer_tag_name_.
void NetworkStack::enableLinkLayerTagging(const std::string &type) {

	vlan->active(false);
	mpls->active(false);
	pppoe->active(false);

	// Links eth <-> tag <-> ip in both directions, records the tag name and
	// activates the tagging protocol.
	auto splice = [&] (auto &tag_mux, auto &tag_proto) {
		mux_eth->addUpMultiplexer(tag_mux);
		tag_mux->addDownMultiplexer(mux_eth);
		tag_mux->addUpMultiplexer(mux_ip);
		mux_ip->addDownMultiplexer(tag_mux);
		link_layer_tag_name_ = type;
		tag_proto->active(true);
	};

	if (type == "vlan") {
		splice(mux_vlan, vlan);
	} else if (type == "mpls") {
		splice(mux_mpls, mpls);
	} else if (type == "pppoe") {
		splice(mux_pppoe, pppoe);
	} else {
		std::ostringstream msg;
		msg << "Unknown tagging type " << type;

		infoMessage(msg.str());
		link_layer_tag_name_ = "";
	}
}

// Grows the preallocated memory of protocol `name` by `value`, logging the
// change first. Unknown protocols are ignored silently.
void NetworkStack::increaseAllocatedMemory(const std::string &name, int value) {

	ProtocolPtr proto = get_protocol(name);
	if (!proto)
		return;

	std::ostringstream msg;
	msg << "Increase allocated memory in " << value << " on " << name << " protocol";
	infoMessage(msg.str());

	proto->increaseAllocatedMemory(value);
}

// Shrinks the preallocated memory of protocol `name` by `value`, logging the
// change first. Unknown protocols are ignored silently.
void NetworkStack::decreaseAllocatedMemory(const std::string &name,int value) {

	ProtocolPtr proto = get_protocol(name);
	if (!proto)
		return;

	std::ostringstream msg;
	msg << "Decrease allocated memory in " << value << " on " << name << " protocol";
	infoMessage(msg.str());

	proto->decreaseAllocatedMemory(value);
}

void NetworkStack::setDynamicAllocatedMemory(const std::string &name, bool value) {

	if (ProtocolPtr proto = get_protocol(name); proto)
		proto->setDynamicAllocatedMemory(value);
}

void NetworkStack::setDynamicAllocatedMemory(bool value) {

	std::for_each (proto_vector_.begin(), proto_vector_.end(), [&] (ProtocolPair const &pp) {
        	ProtocolPtr proto = pp.second;

                proto->setDynamicAllocatedMemory(value);
        });
}

#if defined(JAVA_BINDING)

// Java: installs the TCP regex manager from a raw pointer; null clears it.
// NOTE(review): the raw pointer is adopted by a SharedPointer here, which
// assumes the Java side does not delete it — confirm the ownership contract.
void NetworkStack::setTCPRegexManager(RegexManager *sig) {

	if (sig == nullptr) {
		tcp_regex_mng_.reset();
		return;
	}

	SharedPointer<RegexManager> manager(sig);
	setTCPRegexManager(manager);
}

// Java: installs the UDP regex manager from a raw pointer; null clears it.
// NOTE(review): same ownership assumption as setTCPRegexManager(RegexManager*).
void NetworkStack::setUDPRegexManager(RegexManager *sig) {

	if (sig == nullptr) {
		udp_regex_mng_.reset();
		return;
	}

	SharedPointer<RegexManager> manager(sig);
	setUDPRegexManager(manager);
}

#endif

void NetworkStack::showProtocolSummary(std::basic_ostream<char> &out) const {

	const char *header = "%-14s %-14s %-12s %-8s %-10s %-14s %-14s %-14s %-14s %-8s %-10s";
	const char *format = "%-14s %-14d %-12d %-8d %-10d %-14d %-14s %-14s %-14s %-8s %-10s";
	uint64_t total_packets = 0;
	uint64_t total_bytes = 0;
	uint64_t total_memory = 0;
	uint64_t total_used_memory = 0;
	uint64_t total_map_memory = 0;
	uint32_t total_cmiss = 0;
	uint32_t total_events = 0;
	uint64_t total_cycles = 0;

        ProtocolPtr proto = get_protocol("Ethernet");
        if (proto) {
		total_packets = proto->getTotalPackets();
		total_bytes = proto->getTotalBytes();
	}

	out << "Protocol statistics summary (" << name() << ")\n"
		<< "\t" << boost::format(header) % "Protocol" % "Bytes" % "Packets" % "% Bytes" % "CacheMiss" % "CpuCycles" % "Memory" % "UseMemory" % "CacheMemory" % "Dynamic" % "Events";
	out << "\n";

	for (auto &&pp: proto_vector_) {
		ProtocolPtr proto = pp.second;

		if (!proto->active())
			continue;

		std::string name = proto->name();;
		uint64_t packets = proto->getTotalPackets();
		uint64_t bytes = proto->getTotalBytes();
		uint32_t cmiss = proto->getTotalCacheMisses();

		int64_t per = 0;
		if (total_bytes > 0)
			per = ( bytes * 100.00) / total_bytes;

		std::string_view dynamic_mem = proto->isDynamicAllocatedMemory() ? "yes": "no";

		std::string unit = "Bytes";
		std::string used_unit = "Bytes";
		std::string map_unit = "Bytes";
		uint64_t memory = proto->getTotalAllocatedMemory();
		uint64_t map_memory = memory - proto->getAllocatedMemory();
		uint64_t used_memory = proto->getCurrentUseMemory();
		uint32_t events = proto->getTotalEvents();
		uint64_t cycles = proto->getTotalCpuCycles();

		total_events += events;
		total_cmiss += cmiss;
		total_memory += memory;
		total_used_memory += used_memory;
		total_map_memory += map_memory;
		total_cycles += cycles;

		unitConverter(memory, unit);
		unitConverter(used_memory, used_unit);
		unitConverter(map_memory, map_unit);

		std::ostringstream s_mem;
		s_mem << memory << " " << unit;

		std::ostringstream s_used_mem;
		s_used_mem << used_memory << " " << used_unit;

		std::ostringstream s_map_mem;
		s_map_mem << map_memory << " " << map_unit;

		out << "\t" << boost::format(format) % name % bytes % packets % per % cmiss % cycles % s_mem.str() % s_used_mem.str() % s_map_mem.str() % dynamic_mem % events;
		out << "\n";
	}
	// The Total
	std::string unit = "Bytes";
	std::string used_unit = "Bytes";
	std::string map_unit = "Bytes";

	unitConverter(total_memory, unit);
	unitConverter(total_used_memory, used_unit);
	unitConverter(total_map_memory, map_unit);

	std::ostringstream s_mem;
	s_mem << total_memory << " " << unit;

	std::ostringstream s_used_mem;
	s_used_mem << total_used_memory << " " << used_unit;

	std::ostringstream s_map_mem;
	s_map_mem << total_map_memory << " " << map_unit;

	out << "\t" << boost::format(format) % "Total" % total_bytes % total_packets % 100 % total_cmiss % total_cycles % s_mem.str() % s_used_mem.str() % s_map_mem.str() % "" % total_events;
	out << "\n" << std::endl;
}

// Replaces `out` with a Json array holding one object per active protocol:
// name, packets, bytes, miss, memory, used_memory, cache_memory, events,
// cpu_cycles.
void NetworkStack::showProtocolSummary(Json &out) const {

	out = nlohmann::json::array();

	for (const auto &entry: proto_vector_) {
		const ProtocolPtr &proto = entry.second;

		if (!proto->active())
			continue;

		const int64_t memory = proto->getTotalAllocatedMemory();
		const int64_t map_memory = memory - proto->getAllocatedMemory();
		const int64_t used_memory = proto->getCurrentUseMemory();
		const int32_t events = proto->getTotalEvents();

		Json item;
		item["name"] = proto->name();
		item["packets"] = proto->getTotalPackets();
		item["bytes"] = proto->getTotalBytes();
		item["miss"] = proto->getTotalCacheMisses();
		item["memory"] = memory;
		item["used_memory"] = used_memory;
		item["cache_memory"] = map_memory;
		item["events"] = events;
		item["cpu_cycles"] = proto->getTotalCpuCycles();

		out.push_back(item);
	}
}

void NetworkStack::statistics(Json &out, const std::string &name, const std::string &map_name, int32_t limit) const {

	if (ProtocolPtr proto = get_protocol(name); proto)
		proto->statistics(out, map_name, limit);
}

void NetworkStack::resetCounters(const std::string &name) {

        if (ProtocolPtr pp = get_protocol(name); pp)
		pp->resetCounters();
}

void NetworkStack::showFlows(std::basic_ostream<char> &out, const UserFlowOptions &options) const {

        show_selected_flows(out, options);
}

void NetworkStack::showFlows(Json &out, const UserFlowOptions &options) const {

        show_selected_flows(out, options);
}

// Helper for users that filter by "ip" or "ip/prefix": parses the text into a
// host-order IPv4 network address and its netmask.
//
// Returns {netip, netmask}. A missing, non-numeric, zero, negative or >32
// prefix is treated as /32 (a single host). When the address part is not a
// valid IPv4 address (per inet_aton) both values are 0, so a successful parse
// never yields a zero mask.
std::tuple<uint32_t, uint32_t> get_ipv4_network_mask(const std::string &input) {

	std::string user_ip = input;
	int prefix_len = 32;

	if (std::size_t slash = input.find('/'); slash != std::string::npos) {
		prefix_len = std::atoi(input.substr(slash + 1).c_str());
		// Anything outside [1, 32] would make the shift below undefined.
		if ((prefix_len > 32) or (prefix_len <= 0))
			prefix_len = 32;
		user_ip = input.substr(0, slash);
	}

	in_addr in_netip;
	if (inet_aton(user_ip.c_str(), &in_netip) == 0)
		return {0u, 0u};

	// Shift an unsigned all-ones value: with prefix_len in [1, 32] the shift
	// count is in [0, 31], so there is no signed overflow (the old
	// `(1 << shift) - 1` overflowed for /1) and no oversized shift.
	uint32_t netmask = 0xFFFFFFFFu << (32 - prefix_len);
	uint32_t netip = ntohl(in_netip.s_addr);

	return {netip, netmask};
}

template<typename T>
void NetworkStack::show_selected_flows(T &out, const UserFlowOptions &options) const {

        int current_limit = 0;
	uint32_t src_netip = 0;
	uint32_t src_netmask = 0;
	uint32_t dst_netip = 0;
	uint32_t dst_netmask = 0;

        // Sets all the functions to return true
        std::function<bool (const Flow&)> ipdst_condition = [&] (const Flow& flow) -> bool { return true; };
        std::function<bool (const Flow&)> ipsrc_condition = [&] (const Flow& flow) -> bool { return true; };
        std::function<bool (const Flow&)> portdst_condition = [&] (const Flow& flow) -> bool { return true; };
        std::function<bool (const Flow&)> portsrc_condition = [&] (const Flow& flow) -> bool { return true; };
        std::function<bool (const Flow&)> l7_condition = [&] (const Flow& flow) -> bool { return true; };

        std::function<bool (const Flow&)> limit_condition = [&] (const Flow& flow) -> bool {
                if (current_limit < options.limit) {
                        ++current_limit;
                        return true;
                }
                return false;
        };

        if (options.l7protocol_name.length() > 0) {
                std::function<bool (const Flow&)> condition = [&] (const Flow& flow) -> bool {
			if (boost::iequals(options.l7protocol_name, flow.getL7ProtocolName()))
				return true;

                        return false;
                };
                l7_condition = condition;
        }

        if (options.portsrc > 0) {
                std::function<bool (const Flow&)> condition = [&] (const Flow& flow) -> bool {
                        if (flow.getSourcePort() == options.portsrc)
                                return true;

                        return false;
                };
                portsrc_condition = condition;
        }

        if (options.portdst > 0) {
                std::function<bool (const Flow&)> condition = [&] (const Flow& flow) -> bool {
                        if (flow.getDestinationPort() == options.portdst)
                                return true;

                        return false;
                };
                portdst_condition = condition;
        }

        if (options.ipsrc.length() > 0) {
		std::tuple<uint32_t, uint32_t> src_items = get_ipv4_network_mask(options.ipsrc);

		src_netip = std::get<0>(src_items);
		src_netmask = std::get<1>(src_items);

                std::function<bool (const Flow&)> condition = [&] (const Flow& flow) -> bool {
			if (flow.isIPv4()) {
				uint32_t srcip = ntohl(flow.getSourceAddress());
				if ((srcip & src_netmask) == (src_netip & src_netmask))
					return true;
			} else
				if (options.ipsrc.compare(flow.getSrcAddrDotNotation()) == 0)
					return true;
                        return false;
                };
                ipsrc_condition = condition;
        }

        if (options.ipdst.length() > 0) {
		std::tuple<uint32_t, uint32_t> dst_items = get_ipv4_network_mask(options.ipdst);

		dst_netip = std::get<0>(dst_items);
		dst_netmask = std::get<1>(dst_items);

                std::function<bool (const Flow&)> condition = [&] (const Flow& flow) -> bool {
			if (flow.isIPv4()) {
				uint32_t dstip = ntohl(flow.getDestinationAddress());
				if ((dstip & dst_netmask) == (dst_netip & dst_netmask))
					return true;
			} else
				if (options.ipdst.compare(flow.getDstAddrDotNotation()) == 0)
					return true;
                        return false;
                };
                ipdst_condition = condition;
        }

        // Make the final condition with all the previous ones selected
        std::function<bool (const Flow&)> condition = [&] (const Flow& flow) -> bool {
                if (l7_condition(flow) and portsrc_condition(flow) and portdst_condition(flow))
                        if (ipsrc_condition(flow) and (ipdst_condition(flow)))
                                if (limit_condition(flow))
                                        return true;
                return false;
        };

        // Call the specific member function of the stack
        showFlows(out, condition, options.protocol);
}

#if defined(PYTHON_BINDING)

// Python: builds a UserFlowOptions filter from positional arguments and prints
// the matching flows to the shared output.
void NetworkStack::showSelectedFlows(const std::string &ipsrc, int portsrc, int protocol,
	const std::string &l7protoname, const std::string &ipdst, int portdst, int limit) const {

	UserFlowOptions options;

	// Transport and L7 protocol filters.
	options.protocol = protocol;
	options.l7protocol_name = l7protoname;

	// Endpoint filters.
	options.ipsrc = ipsrc;
	options.portsrc = portsrc;
	options.ipdst = ipdst;
	options.portdst = portdst;

	options.limit = limit;

	show_selected_flows(OutputManager::getInstance()->out(), options);
}

#endif

} // namespace aiengine
